"""
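Persona Generator for Multiple LLMs

This script generates unique personas (300 by default; see --num_personas)
using one of several LLM backends:
- Llama 3.1 (8B and 70B)
- Gemma 2 (9B and 27B)
- GPT-4
- Gemini Pro

Results are written to ./personas/ as CSV, e.g. personas_llama8b_temp07.csv.

Usage:
    python persona_generator.py --model llama --model_size 8 --temp 0.7
    python persona_generator.py --model gemma --model_size 9 --temp 0.7
    python persona_generator.py --model gpt4 --temp 0.7
    python persona_generator.py --model gemini --temp 0.7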
"""

import os
import argparse
import csv
import time
import torch
import gc
from datetime import datetime
from tqdm import tqdm
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

# Configure API clients for OpenAI and Google. A missing package only disables
# that backend: the matching get_completion_* function will then hit an
# undefined name at call time and return None.
try:
    import openai
    openai.api_key = os.getenv('OPENAI_API_KEY')
except ImportError:
    print("OpenAI package not installed. GPT-4 functionality will not be available.")

try:
    from google import generativeai
    generativeai.configure(api_key=os.getenv('GOOGLE_API_KEY'))
except ImportError:
    print("Google Generative AI package not installed. Gemini Pro functionality will not be available.")

# Login to Hugging Face only when a token is configured. Calling login() without
# a token falls back to an interactive prompt, which would block a batch run --
# including GPT-4/Gemini runs that never touch the Hub. main() separately warns
# when HF_TOKEN is missing for Llama/Gemma.
_hf_token = os.getenv('HF_TOKEN')
try:
    from huggingface_hub import login
except ImportError:
    print("Hugging Face Hub package not installed. Llama/Gemma functionality will not be available.")
else:
    if _hf_token:
        try:
            login(_hf_token)
        except Exception as login_error:
            # An invalid token or network problem should not abort runs that
            # only use the OpenAI/Gemini backends; gated-model loading will
            # report its own error later if the login was actually needed.
            print(f"Hugging Face login failed: {login_error}")

def get_completion_llama_gemma(model_name, model_size, prompt, temperature):
    """Get completion from Llama or Gemma models"""
    try:
        import transformers
        
        # Select the correct model ID
        if model_name == "llama":
            if model_size == 8:
                model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
            elif model_size == 70:
                model_id = "meta-llama/Meta-Llama-3.1-70B-Instruct"
            else:
                raise ValueError(f"Unsupported Llama model size: {model_size}")
        else:  # gemma
            if model_size == 9:
                model_id = "google/gemma-2-9b-it"
            elif model_size == 27:
                model_id = "google/gemma-2-27b-it"
            else:
                raise ValueError(f"Unsupported Gemma model size: {model_size}")
        
        print(f"Loading model: {model_id}")
        
        # Create the pipeline
        pipeline = transformers.pipeline(
            "text-generation",
            model=model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16
        )
        
        # Prepare messages
        messages = [
            {"role": "user", "content": prompt},
        ]
        
        print(f"Generating response with temperature={temperature}")
        
        # Set up generation parameters based on temperature
        if temperature == 0:
            outputs = pipeline(
                messages,
                do_sample=False,
                top_p=None,
                temperature=None,
                max_new_tokens=300
            )
        else:
            outputs = pipeline(
                messages,
                do_sample=True,
                temperature=temperature,
                max_new_tokens=300
            )
            
        # Extract the generated text
        result = outputs[0]["generated_text"][-1]["content"]
        print(f"Generated response of length {len(result)}")
        return result
    except Exception as e:
        print(f"Error with {model_name}: {str(e)}")
        import traceback
        traceback.print_exc()
        
        # Try to free up memory in case of CUDA out of memory errors
        if "CUDA out of memory" in str(e):
            try:
                import gc
                import torch
                gc.collect()
                torch.cuda.empty_cache()
                print("Cleared CUDA cache to free up memory")
            except:
                pass
            
        return None

def get_completion_gpt4(prompt, temperature, max_tokens=300):
    """Get a completion from OpenAI's GPT-4 chat model.

    If the ``openai`` package failed to import at module load, the undefined
    name raises here and is handled like any other error.

    Args:
        prompt: Text sent as the single user message.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Upper bound on generated tokens (default 300, matching
            the other backends in this file).

    Returns:
        The reply's message content, or None on any error (the traceback is
        printed so the caller's retry loop can continue).
    """
    try:
        print("Sending request to OpenAI API")
        response = openai.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "user", "content": prompt}
            ],
            temperature=temperature,
            max_tokens=max_tokens,
        )
        result = response.choices[0].message.content
        print("Received response from OpenAI API")
        return result
    except Exception as e:
        print(f"Error with GPT-4: {str(e)}")
        import traceback
        traceback.print_exc()
        return None

def get_completion_gemini(prompt, temperature, max_output_tokens=300):
    """Get a completion from Google's Gemini Pro model.

    If ``google.generativeai`` failed to import at module load, the undefined
    name raises here and is handled like any other error.

    Args:
        prompt: Text sent as the request content.
        temperature: Sampling temperature.
        max_output_tokens: Upper bound on generated tokens (default 300,
            matching the other backends in this file).

    Returns:
        ``response.text`` from the API, or None on any error (the traceback
        is printed so the caller's retry loop can continue).
    """
    try:
        print("Sending request to Gemini API")
        # NOTE(review): confirm the 'gemini-pro' model name is still served by the API.
        model = generativeai.GenerativeModel('gemini-pro')
        generation_config = generativeai.GenerationConfig(
            temperature=temperature,
            max_output_tokens=max_output_tokens,
            top_p=0.95,
        )
        response = model.generate_content(
            prompt,
            generation_config=generation_config,
        )
        result = response.text
        print("Received response from Gemini API")
        return result
    except Exception as e:
        print(f"Error with Gemini Pro: {str(e)}")
        import traceback
        traceback.print_exc()
        return None

def get_completion(args, prompt):
    """Route *prompt* to the backend selected by ``args.model``.

    Llama and Gemma share the local Hugging Face path (which also needs
    ``args.model_size``); ``gpt4`` and ``gemini`` call their hosted APIs.
    All backends use ``args.temp`` as the sampling temperature.

    Raises:
        ValueError: if ``args.model`` names no known backend.
    """
    backend = args.model
    if backend in ("llama", "gemma"):
        return get_completion_llama_gemma(backend, args.model_size, prompt, args.temp)
    if backend == "gpt4":
        return get_completion_gpt4(prompt, args.temp)
    if backend == "gemini":
        return get_completion_gemini(prompt, args.temp)
    raise ValueError(f"Unknown model: {args.model}")

def _model_tag(args):
    """Return the model label used in the output filename.

    Examples: ``llama8b``, ``gemma27b``, ``gpt4``, ``gemini_pro``.

    Raises:
        ValueError: if ``args.model`` is not one of the supported backends.
    """
    if args.model in ("llama", "gemma"):
        return f"{args.model}{args.model_size}b"
    if args.model == "gpt4":
        return args.model
    if args.model == "gemini":
        return f"{args.model}_pro"
    raise ValueError(f"Unknown model: {args.model}")


def _generate_unique_persona(args, prompt, seen, max_retries=10):
    """Query the model until it returns a new persona longer than 20 characters.

    Each of these consumes one attempt: a None response, an exception, too-short
    text, or an exact duplicate of an entry in *seen*. A None response or an
    exception is followed by a 2-second pause.

    Returns:
        ``(persona, failed_attempts)``. *persona* is the stripped text (also
        added to *seen*), or None if all *max_retries* attempts were used up.
    """
    retries = 0
    while retries < max_retries:
        try:
            response = get_completion(args, prompt)

            if response is None:
                print("Failed to get a response, retrying...")
                retries += 1
                time.sleep(2)
                continue

            response = response.strip()

            # Accept only non-trivial, previously unseen personas.
            if len(response) > 20 and response not in seen:
                seen.add(response)
                return response, retries

            if len(response) <= 20:
                print("Generated too short persona. Retrying...")
            else:
                print("Generated duplicate persona. Retrying...")
            retries += 1
        except Exception as e:
            print(f"Error generating persona: {str(e)}")
            retries += 1
            time.sleep(2)
    return None, retries


def run(args):
    """Generate ``args.num_personas`` personas with the selected model and save them to CSV.

    Each iteration asks for one unique persona (up to 10 attempts). Iterations
    that never produce one are skipped, so the CSV can contain fewer rows than
    requested. Each row is ``(iteration number, persona text)``.

    Raises:
        ValueError: if ``args.model`` is not a supported backend.
        FileExistsError: if the output file exists and ``args.force_overwrite``
            is not set.
    """
    # Build the filename before touching the filesystem so an unsupported model
    # fails fast, before any directory is created.
    model_str = _model_tag(args)
    temp_str = str(args.temp).replace('.', '')  # e.g. 0.7 -> "07"
    csv_filename = f"./personas/personas_{model_str}_temp{temp_str}.csv"

    os.makedirs('./personas', exist_ok=True)
    if os.path.exists(csv_filename) and not args.force_overwrite:
        raise FileExistsError(f"Output file already exists: {csv_filename}. Use --force_overwrite to overwrite.")

    generated_personas = set()
    prompt = "Create a persona (2-3 sentences long): "
    # API backends pause longer between iterations to respect rate limits.
    sleep_time = 3 if args.model in ["gpt4", "gemini"] else 1

    print(f"Generating {args.num_personas} personas using {args.model}{' ' + str(args.model_size) if args.model_size else ''}")
    print(f"Temperature: {args.temp}")
    print(f"Output file: {csv_filename}")

    with open(csv_filename, "w", newline="", encoding="utf-8") as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(["Persona", "Description"])

        for i in tqdm(range(args.num_personas)):
            persona, retries = _generate_unique_persona(args, prompt, generated_personas)

            if persona is not None:
                csvwriter.writerow([i + 1, persona])

                display_response = persona[:100] + "..." if len(persona) > 100 else persona
                print(f"Persona {i + 1}:")
                print(display_response)
                print("-" * 40)

                # Flush periodically so progress survives an interrupted run.
                if i % 10 == 0:
                    csvfile.flush()
            else:
                print(f"Failed to generate unique persona after {retries} retries for iteration {i+1}.")

            time.sleep(sleep_time)

    print(f"\nGenerated {len(generated_personas)} unique personas")
    print(f"Results saved to: {csv_filename}")

def main():
    """Parse and validate command-line options, then run persona generation.

    Returns:
        0 on success, 1 if ``run`` raised. Invalid option combinations are
        reported through ``parser.error``, which exits with status 2.
    """
    parser = argparse.ArgumentParser(description='Generate personas using various LLM models')

    # Model selection
    parser.add_argument('--model', choices=['llama', 'gemma', 'gpt4', 'gemini'], required=True,
                      help='Model type to use (llama, gemma, gpt4, gemini)')
    parser.add_argument('--model_size', type=int,
                      choices=[8, 70, 9, 27],
                      help='Model size in billions: Llama (8, 70), Gemma (9, 27)')

    # Generation parameters
    parser.add_argument('--temp', type=float, default=0.7,
                      help='Temperature for generation (higher = more creative)')
    parser.add_argument('--num_personas', type=int, default=300,
                      help='Number of personas to generate')

    # Output settings
    parser.add_argument('--force_overwrite', action='store_true',
                      help='Overwrite existing output file if it exists')

    args = parser.parse_args()

    # Sizes available for each local model family; the first one is the default.
    local_model_sizes = {'llama': (8, 70), 'gemma': (9, 27)}
    if args.model in local_model_sizes:
        sizes = local_model_sizes[args.model]
        if args.model_size is None:
            args.model_size = sizes[0]
            print(f"No model size specified, defaulting to {args.model_size}B for {args.model}")
        elif args.model_size not in sizes:
            # argparse's `choices` accepts any family's sizes, so e.g. llama/27
            # gets through; reject it here instead of failing on every attempt.
            parser.error(f"--model_size {args.model_size} is not available for {args.model}; "
                         f"choose one of {', '.join(map(str, sizes))}")

    # A negative temperature is not a meaningful sampling setting; reject it up
    # front rather than letting every retried generation call fail.
    if args.temp < 0:
        parser.error("--temp must be >= 0")

    # Check for environment variables for API-based models
    if args.model == 'gpt4' and not os.getenv('OPENAI_API_KEY'):
        print("Warning: OPENAI_API_KEY environment variable not set. GPT-4 will not work.")

    if args.model == 'gemini' and not os.getenv('GOOGLE_API_KEY'):
        print("Warning: GOOGLE_API_KEY environment variable not set. Gemini Pro will not work.")

    if args.model in ['llama', 'gemma'] and not os.getenv('HF_TOKEN'):
        print("Warning: HF_TOKEN environment variable not set. Hugging Face models may not work properly.")

    try:
        run(args)
    except Exception as e:
        print(f"Error running persona generation: {str(e)}")
        import traceback
        traceback.print_exc()
        return 1

    return 0